# Imports

from ssm4epi.models.regional_growth_factor import (
    key,
    n_iterations,
    N_mle,
    N_meis,
    N_posterior,
    percentiles_of_interest,
    make_aux,
    dates_full,
    cases_full,
    n_ij,
    n_tot,
    n_pop,
    account_for_nans,
    growth_factor_model,
)

from tqdm.notebook import tqdm
import pandas as pd

import jax.numpy as jnp
import jax
import jax.random as jrn

from isssm.importance_sampling import prediction
from isssm.laplace_approximation import laplace_approximation as LA
from isssm.modified_efficient_importance_sampling import (
    modified_efficient_importance_sampling as MEIS,
)

from pyprojroot.here import here

jax.config.update("jax_enable_x64", True)
from isssm.estimation import initial_theta
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from datetime import date

# Data

# County lookup table: derive each county's state code as the county code
# integer-divided by 1000, then give the columns county-specific names.
county_dict = (
    pd.read_csv(here("data/processed/ags_county_dict.csv"))
    .assign(ags_state=lambda counties: counties["ags"] // 1000)
    .rename(columns={"name": "name_county", "ags": "ags_county"})
)

# State lookup table, keyed by the same "ags_state" column as the counties.
state_dict = pd.read_csv(here("data/processed/ags_state_dict.csv")).rename(
    columns={"ags": "ags_state"}
)

# State code of every county, in the row order produced by the inner merge.
ags_state_order = county_dict.merge(state_dict, on="ags_state")[
    "ags_state"
].to_numpy()
def incidence_per_state(I_county):
    """Sum county-level values into per-state totals.

    Each entry of ``I_county`` is added to the bin given by the matching entry
    of the module-level ``ags_state_order``. Bin 0 is dropped from the result.

    NOTE(review): dropping bin 0 presumably reflects state codes starting at
    1 -- confirm against the state lookup table.
    """
    totals_by_code = jnp.bincount(ags_state_order, weights=I_county)
    return totals_by_code[1:]

# Parameters

# First and last date of the weekly (Saturday-anchored) grid built below.
min_date = date.fromisoformat("2020-10-17")
max_date = date.fromisoformat("2022-01-22")

# Number of weeks separating a window's initial date from its final date.
np1 = 10

# Dates

# Every Saturday between min_date and max_date (inclusive), shifted back by
# np1 weeks and rendered as "YYYY-MM-DD" strings.
_weekly_grid = pd.date_range(
    start=min_date, end=max_date, freq="W-SAT", inclusive="both"
)
initial_dates = (_weekly_grid - pd.Timedelta(weeks=np1)).strftime("%Y-%m-%d").tolist()
del _weekly_grid  # keep the module namespace identical to the original cell


def initial_to_final_date(initial_date: date) -> date:
    """Return the date ``np1`` weeks after ``initial_date``.

    ``np1`` is read from module scope. Callers in this file pass pandas
    timestamps (the result of ``pd.to_datetime``) rather than plain
    ``datetime.date`` objects.
    """
    window_length = pd.Timedelta(weeks=np1)
    return initial_date + window_length


# Final date of every window, as "YYYY-MM-DD" strings aligned with
# ``initial_dates``; the whole index is shifted in one vectorised call.
final_dates = (
    initial_to_final_date(pd.to_datetime(initial_dates))
    .strftime("%Y-%m-%d")
    .tolist()
)
len(initial_dates)
# Output of len(initial_dates): 67
def make_aux(date, cases_full, n_ij, n_tot, np1):
    """Load the case window for one forecast date from its data snapshot.

    Reads ``data/processed/RKI_county_<snapshot>.csv``, where ``<snapshot>``
    is five days before the window's final date (``date`` plus ``np1`` weeks),
    pivots it to a date-by-``ags`` table of ``cases`` and keeps the last
    ``np1 + 1`` rows.

    Args:
        date: Initial date of the window; anything ``pd.to_datetime`` accepts.
        cases_full: Ignored. Kept only for signature compatibility; the cases
            are always re-read from the snapshot file.
        n_ij: Passed through unchanged.
        n_tot: Passed through unchanged.
        np1: Number of weeks between the initial and the final date. Used for
            both the snapshot file date and the number of rows returned.

    Returns:
        Tuple ``(cases, n_ij, n_tot)`` where ``cases`` is a float64 JAX array
        holding the last ``np1 + 1`` rows of the pivoted snapshot.
    """
    # Derive the final date from the ``np1`` argument rather than the
    # module-level constant, so the snapshot file and the slice below always
    # describe the same window.
    final_date = pd.to_datetime(date) + pd.Timedelta(weeks=np1)
    snapshot_date = final_date - pd.Timedelta(days=5)

    df_weekly_cases = pd.read_csv(
        here()
        / f"data/processed/RKI_county_{snapshot_date.strftime('%Y-%m-%d')}.csv"
    ).pivot(index="date", columns="ags", values="cases")

    snapshot_cases = jnp.asarray(df_weekly_cases.to_numpy(), dtype=jnp.float64)
    cases = snapshot_cases[-(np1 + 1) :]
    return cases, n_ij, n_tot
# Fixed root PRNG key: every forecast run splits its sampling key from this
# one, so results are reproducible across runs.
GLOBAL_KEY = jrn.PRNGKey(seed=4534365463653)


def f_pred(x, s, y):
    """Prediction target: last-row county values followed by state totals.

    Only ``y`` is used; ``x`` and ``s`` are accepted because ``prediction``
    is given this function as its callback (presumably called as
    ``f(x, s, y)`` -- confirm against the isssm API).

    Returns the concatenation of ``y[-1]`` and its per-state aggregation
    computed by ``incidence_per_state``.
    """
    latest_county = y[-1]
    return jnp.concatenate([latest_county, incidence_per_state(latest_county)])
def create_regional_forecast_dataframe(initial_date):
    """Compute and save the county/state forecast table for one window.

    Loads the pickled fit stored for ``initial_date``, marks the last row of
    the observation window as missing, runs ``prediction`` with ``f_pred`` and
    writes mean, standard deviation and the percentiles of interest to
    ``data/results/4_local_outbreak_model/regional_forecasts_{initial_date}.csv``.

    Args:
        initial_date: Initial date of the window as a ``"%Y-%m-%d"`` string.
            It must match exactly one entry of ``dates_full``, otherwise the
            single-element unpacking below raises ``ValueError``.

    Returns:
        None. The forecast table is written to disk.
    """
    # Row layout of the prediction output: counties first, then states.
    # NOTE(review): 400 and 16 are assumed to match the length of what f_pred
    # returns -- confirm against the county/state lookup tables.
    n_counties, n_states = 400, 16
    n_samples = 1000
    # The forecast targets the last time point of the np1-week window. Tying
    # it to np1 keeps it inside ``dates`` if np1 ever changes.
    t_forecast = np1

    # NOTE: pickle.load can execute arbitrary code; only load result files
    # produced by this project's own fitting step.
    with open(
        here() / f"data/results/4_local_outbreak_model/results_{initial_date}.pickle",
        "rb",
    ) as f:
        result = pickle.load(f)

    # The stored result is a 5-tuple; only the fitted parameters and the
    # Laplace-approximation proposal are needed here.
    (theta0, proposal_la, _, _, _) = result
    (dates_index,) = jnp.where(dates_full == initial_date)[0]
    dates = dates_full[dates_index : dates_index + np1]

    aux = make_aux(initial_date, cases_full, n_ij, n_tot, np1)

    # Drop the first row of the window, then blank the last row so that it is
    # treated as missing (together with any NaNs already in the data).
    y = aux[0][1:]
    y_nan = y.at[-1].set(jnp.nan)
    missing_inds = jnp.isnan(y_nan)

    # One call yields both the missing-data model and the adjusted
    # observations; the original evaluated it twice with identical arguments.
    fitted_model, y_miss = account_for_nans(
        growth_factor_model(theta0, aux), y_nan, missing_inds
    )

    # The same subkey is derived on every call, making each run reproducible.
    _, subkey = jrn.split(GLOBAL_KEY)
    preds = prediction(
        f_pred,
        y_miss,
        proposal_la,
        fitted_model,
        n_samples,
        subkey,
        percentiles_of_interest,
        growth_factor_model(theta0, aux),
    )
    df = pd.DataFrame(
        {
            "variable": [
                *[f"y_county_{c}" for c in range(1, n_counties + 1)],
                *[f"y_state_{s}" for s in range(1, n_states + 1)],
            ],
            "c": [*range(1, n_counties + 1), *range(1, n_states + 1)],
            "t": [t_forecast] * (n_counties + n_states),
            "mean": preds[0],
            "sd": preds[1],
            **{
                f"{p * 100:.1f} %": preds[2][i, :]
                for i, p in enumerate(percentiles_of_interest)
            },
        }
    )

    # ``t`` is 1-based, ``dates`` is 0-based.
    df["date"] = [dates[t - 1] for t in df["t"]]
    df.to_csv(
        here(
            f"data/results/4_local_outbreak_model/regional_forecasts_{initial_date}.csv"
        ),
        index=False,
    )
# Build and save the regional forecast CSV for every window start date,
# showing a notebook progress bar while doing so.
for initial_date in tqdm(initial_dates):
    create_regional_forecast_dataframe(initial_date)